# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate several metrics for Wasserstein transformation."""

import collections

from absl import app
from absl import flags
import numpy as np
import pandas as pd
from scipy.spatial import distance
from six.moves import range
import six.moves.cPickle as pickle
from sklearn import ensemble
from sklearn import metrics
import sklearn.linear_model as lm
from tensorflow.compat.v1 import gfile

from correct_batch_effects_wdn import distance as distance_analysis
from correct_batch_effects_wdn import evaluate
from correct_batch_effects_wdn import io_utils
from correct_batch_effects_wdn import ljosa_preprocessing
from correct_batch_effects_wdn import metadata
from correct_batch_effects_wdn import transform

# Handle to absl's global flag registry; main() reads the parsed values from it.
FLAGS = flags.FLAGS

# Input and output locations read by main().
flags.DEFINE_string("input_df", None, "Path to the embedding dataframe.")
flags.DEFINE_string("transformation_file", None,
                    "Path to the Wasserstein transformation file.")
flags.DEFINE_string("output_file", None, "Where to save evaluation results.")
# Evaluation options; main() forwards these to the cross-validation and metric
# routines.
flags.DEFINE_integer("num_bootstrap", 2, "number of bootstrap samples to use")
flags.DEFINE_bool("percentile_normalize", False, "Normalize 1-99 percentile to "
                  "match controls.")
flags.DEFINE_bool("factor_analysis", False, "Apply factor analysis")
# main() sorts the training steps and keeps only the first `num_steps` of them.
flags.DEFINE_integer("num_steps", 10, "number of steps to use")

# Default seed for the bootstrap routines and the batch classifiers below.
SEED = 1234


def load_contents(file_path):
  """Load the pickled contents of a file.

  The file is opened in binary mode: a pickle is a byte stream, and under
  Python 3 `pickle.loads` rejects text input.

  NOTE(security): unpickling can execute arbitrary code. Only point this at
  files that come from a trusted source.

  Args:
    file_path (string): location of file to load.

  Returns:
    contents: object stored in the pickle file.
  """
  with gfile.GFile(file_path, mode="rb") as f:
    return pickle.loads(f.read())


def apply_post_processing(df, percent_norm, factor_analys):
  """Apply percentile normalization and/or factor analysis.

  If using WDN, this must be done after the transformation. Thus, this function
  is applied after wasserstein_transform.

  Factor analysis is run with NumPy's global RNG seeded to 0. The caller's RNG
  state is saved beforehand and restored afterwards, so that calling this
  function inside a bootstrap loop does not reset the random stream the loop
  itself is drawing from.

  Args:
    df (pandas dataframe): input dataframe
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis
  Returns:
    df (pandas dataframe): post-processed dataframe. The input is returned
      unchanged when both options are off.
  """
  if percent_norm:
    df = ljosa_preprocessing.normalize_ljosa(df)
  if factor_analys:
    # Seed locally, then hand the global RNG back exactly as we found it.
    rng_state = np.random.get_state()
    np.random.seed(0)
    try:
      df = transform.factor_analysis(df, 0.15, 50)
    finally:
      np.random.set_state(rng_state)
  return df


def wasserstein_transform(contents, emb_df, time_step, level=metadata.BATCH):
  """Apply the per-group affine Wasserstein map stored for one time step.

  The dataframe is split by the index level `level`. Each group is multiplied
  by the weight stored under the key `(str(label),)` and shifted by the bias
  stored under the same key; the transformed groups are then concatenated.

  Args:
    contents (dict): contents generated by Wasserstein network; indexed by
      time step, each entry holding "w_val" and "b_val".
    emb_df (pandas dataframe): dataframe to apply transformation to
    time_step (int): which timestep to use
    level (string): at which level to apply Wasserstein transformation. This
      should match the one used in training, e.g. batch or plate

  Returns:
    emb_df_trans (pandas dataframe): transformed pandas dataframe.
  """
  step_params = contents[time_step]
  biases = step_params["b_val"]
  weights = step_params["w_val"]

  def _affine(label, group_df):
    key = (str(label),)
    return group_df.dot(weights[key]) + biases[key]

  transformed_groups = [
      _affine(label, group_df)
      for label, group_df in emb_df.groupby(level=level)
  ]
  return pd.concat(transformed_groups)


def filter_comp(compound_pairs, left_comp):
  """Keep only the pairs whose first element belongs to a given compound.

  Args:
    compound_pairs (list): pairs of compounds; the first element of each pair
      is looked up with the metadata.COMPOUND key.
    left_comp (string): compound that was left out; only pairs whose first
      element has this compound are kept.

  Returns:
    filtered (list): the matching pairs, in their original order.
  """
  return [
      pair for pair in compound_pairs
      if pair[0][metadata.COMPOUND] == left_comp
  ]


def transform_and_means(contents, emb_df, step, linear=True,
                        percent_norm=False,
                        factor_analys=False,
                        drop_controls=True):
  """Wasserstein-transform the embeddings and average them per treatment.

  When the transformation is linear and no post-processing is requested, the
  means are taken first and the transformation is applied to the (much
  smaller) dataframe of means. Otherwise the full dataframe is transformed,
  post-processed, and only then averaged.

  Args:
    contents (dict): Transformation file.
    emb_df (pandas dataframe): Dataframe to transform and take means of.
    step (int): step where to take Wasserstein transform.
    linear (bool): Whether or not transformation is linear.
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis
    drop_controls (bool): whether or not to pass the result through
      transform.drop_unevaluated_comp before returning it.

  Returns:
    transformed_means (pandas dataframe): dataframe after transform and
      taking the mean.

  Raises:
    ValueError: if "treatment_group" is not one of the index names.
  """
  if "treatment_group" not in emb_df.index.names:
    raise ValueError("Must have treatment_group in embeddings index names.")

  group_levels = [
      metadata.MOA, metadata.COMPOUND, metadata.CONCENTRATION, metadata.BATCH,
      metadata.TREATMENT_GROUP
  ]
  needs_post_processing = percent_norm or factor_analys

  if linear and not needs_post_processing:
    # Means-first shortcut.
    per_treatment = emb_df.groupby(level=group_levels).mean()
    transformed_means = wasserstein_transform(contents, per_treatment, step)
  else:
    transformed = wasserstein_transform(contents, emb_df, step)
    post_processed = apply_post_processing(transformed, percent_norm,
                                           factor_analys)
    transformed_means = post_processed.groupby(level=group_levels).mean()

  if not drop_controls:
    return transformed_means
  return transform.drop_unevaluated_comp(transformed_means)


def find_time_step_max(emb_df_train, contents, steps):
  """Find time step where training accuracy is maximized.

  Notice we do NOT apply post-processing of embeddings (e.g. normalization or
  factor analysis) at this stage.

  The steps are evaluated in sorted order, and the returned step is looked up
  in that same sorted sequence, so the result is correct even when `steps` is
  passed unsorted.

  Args:
    emb_df_train (pandas dataframe): Dataframe with left-one-out compound.
      Currently this should not include the controls.
    contents (dict): Contents from Wasserstein training routine
    steps (list): Steps for training

  Returns:
    time_step_max (int): Time step maximizing the sum of the column means of
      "Accuracy NSC" and "Accuracy NSC NSB" from
      evaluate.make_knn_moa_dataframe (presumably averaged over k=1...4).

  Raises:
    ValueError: if "treatment_group" is not an index name, or if `steps` is
      empty.
  """
  if "treatment_group" not in emb_df_train.index.names:
    raise ValueError("Must have treatment_group in embeddings index names.")
  ordered_steps = sorted(steps)
  if not ordered_steps:
    raise ValueError("steps must contain at least one time step.")

  acc_nsc = []
  acc_nsc_nsb = []
  for time_step in ordered_steps:
    ## Right now, we are NOT applying post-processing for cross-validation.
    ## This could be done by setting the values below to the same as the FLAG
    ## values. However, this would be slower, and also (for factor analysis)
    ## we would have to include the controls in emb_df_train, and drop them
    ## later. The reason for this is that they are used to determine the
    ## factor analysis transformation.
    means = transform_and_means(contents, emb_df_train, time_step,
                                percent_norm=False, factor_analys=False)
    df_moa = evaluate.make_knn_moa_dataframe(means)
    acc_nsc.append(df_moa[["Accuracy NSC"]])
    acc_nsc_nsb.append(df_moa[["Accuracy NSC NSB"]])
  # One column per evaluated step, in the order of ordered_steps.
  acc_nsc = np.concatenate(acc_nsc, axis=1)
  acc_nsc_nsb = np.concatenate(acc_nsc_nsb, axis=1)
  # Select the time step by average accuracy.
  combined = np.mean(acc_nsc, 0) + np.mean(acc_nsc_nsb, 0)
  return ordered_steps[np.argmax(combined)]


def get_index_for_name(df, name):
  """Return the position of the index level called `name`.

  Args:
    df (pandas dataframe): input dataframe
    name (string): Name of desired index
  Returns:
    name_idx (int): position of `name` within df.index.names

  Raises:
    ValueError: if no index level, or more than one, has that name.
  """
  positions = []
  for position, level_name in enumerate(df.index.names):
    if level_name == name:
      positions.append(position)

  num_found = len(positions)
  if num_found == 0:
    raise ValueError("No index with name %s found." % name)
  if num_found > 1:
    raise ValueError("Found more than 1 index with name %s." % name)
  return positions[0]


def confusion_matrix_from_dist(dist, k, which_filter,
                               match_metadata_values,
                               match_metadata=metadata.MOA):
  """Build a confusion matrix from k-nearest-neighbor matches.

  Runs evaluate.k_nearest_neighbors on the distance dataframe and feeds the
  resulting correct / mismatched pairs to evaluate.get_confusion_matrix.
  Diagonal elements correspond to correct predictions, while off-diagonal
  entries correspond to incorrectly identified treatments.

  Args:
    dist (pandas dataframe): cosine distance dataframe between treatments.
    k (int): number of nearest neighbors to use.
    which_filter (function): filter used to exclude treatments, e.g.
      evaluate.not_same_compound_filter or
      evaluate.not_same_compound_or_batch_filter
    match_metadata_values: ordered list of metadata values to use.
    match_metadata: String, metadata that we would like to match.

  Returns:
    confusion_matrix (2-dimensional numpy array): confusion matrix among the
      values of `match_metadata`.

  """
  correct, mismatch = evaluate.k_nearest_neighbors(dist, k, which_filter)
  return evaluate.get_confusion_matrix(
      correct,
      mismatch,
      match_metadata_values=match_metadata_values,
      match_metadata=match_metadata)


def update_stats_new_compound(comp_set, dist, k, which_filter, correct_list,
                              mismatch_list, match_metadata_values,
                              confusion_matrix=None):
  """Accumulate k-NN results for one left-out compound, in place.

  Helper for cross_val_train. The k-NN matches are restricted to pairs whose
  first element is the left-out compound, then appended to the running lists
  and, when given, added to the running confusion matrix for this k.

  Args:
    comp_set (dict): maps "a" to remaining compounds and "b" to left-out
      compounds.
    dist (pandas dataframe): cosine distance dataframe to use.
    k (int): number of nearest neighbors.
    which_filter (function): Type of filter to use.
    correct_list (dict of lists): correct pairs per k; entry k is extended.
    mismatch_list (dict of lists): mismatched pairs per k; entry k is
      extended.
    match_metadata_values (iterable): sorted list of MOAs.
    confusion_matrix (dict of numpy arrays): confusion matrices per k; entry k
      is updated. Skipped when None.
  """
  all_correct, all_mismatch = evaluate.k_nearest_neighbors(
      dist, k, which_filter)
  left_out = comp_set["b"]
  kept_correct = filter_comp(all_correct, left_out)
  kept_mismatch = filter_comp(all_mismatch, left_out)

  correct_list[k].extend(kept_correct)
  mismatch_list[k].extend(kept_mismatch)

  if confusion_matrix is None:
    return
  if not (kept_correct or kept_mismatch):
    ## nothing to add; avoid issues when empty
    return
  confusion_matrix[k] += evaluate.get_confusion_matrix(
      kept_correct, kept_mismatch,
      match_metadata_values=match_metadata_values,
      match_metadata=metadata.MOA)


def cross_val_train(emb_df_clean, contents, steps, list_of_comp_set, n_comp,
                    report_confusion_matrix=True, percent_norm=False,
                    factor_analys=False):
  """Leave-one-compound-out cross validation with a per-compound stopping time.

  For each of the first `n_comp` entries of `list_of_comp_set`, the best time
  step is chosen on the remaining compounds (find_time_step_max). The cosine
  distance matrix at that time step is then used to score the left-out
  compound with k-NN for k=1...4. Distance matrices are cached per time step.

  Args:
    emb_df_clean (pandas dataframe): embeddings WITH unevaluated compounds.
    contents (dict): Contents from Wasserstein training routine
    steps (list): Steps for training
    list_of_comp_set (list): dictionaries for each compound for leave-one-out
    n_comp (int): number of compounds
    report_confusion_matrix (bool): whether or not to include confusion matrix.
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis

  Returns:
    list_of_time_step_max (list): best stopping time for each compound
    cross_validated_scores (dict): Contains cross-validated accuracy scores and
      (optionally) confusion matrices.

  """
  emb_df_valid = transform.drop_unevaluated_comp(emb_df_clean)
  moa_values = sorted(
      emb_df_valid.index.get_level_values(level=metadata.MOA).unique())
  num_moa = len(moa_values)
  compounds_of_rows = emb_df_valid.index.get_level_values(
      level=metadata.COMPOUND)

  # Running k-NN results, keyed by k.
  correct_nsc = collections.defaultdict(list)
  mismatch_nsc = collections.defaultdict(list)
  correct_nscb = collections.defaultdict(list)
  mismatch_nscb = collections.defaultdict(list)

  confusion_matrices_nsc = None
  confusion_matrices_nscb = None
  if report_confusion_matrix:
    confusion_matrices_nsc = collections.defaultdict(
        list, {k: np.zeros((num_moa, num_moa)) for k in range(1, 5)})
    confusion_matrices_nscb = collections.defaultdict(
        list, {k: np.zeros((num_moa, num_moa)) for k in range(1, 5)})

  # (filter, correct, mismatch, confusion matrices) for the two variants,
  # in the order they are updated for each k.
  variants = (
      (evaluate.not_same_compound_filter, correct_nsc, mismatch_nsc,
       confusion_matrices_nsc),
      (evaluate.not_same_compound_or_batch_filter, correct_nscb,
       mismatch_nscb, confusion_matrices_nscb),
  )

  list_of_time_step_max = []
  dist_at_time = {}  ## cache of distance matrices, keyed by time step
  for comp_idx in range(n_comp):
    print("cross-validation for compound %s" %comp_idx)
    comp_set = list_of_comp_set[comp_idx]

    ## dataframe excluding the left-out compound
    in_training_set = compounds_of_rows.isin(comp_set["a"])
    emb_df_train = emb_df_valid[in_training_set]
    if "treatment_group" not in emb_df_train.index.names:
      raise ValueError("Must have treatment_group in embeddings index names.")

    ## Best time step for this left-out compound. This is the bottleneck,
    ## since every time step has to be evaluated.
    time_step_max = find_time_step_max(emb_df_train, contents, steps)
    list_of_time_step_max.append(time_step_max)

    if time_step_max not in dist_at_time:
      ## cosine distances between treatment means at time_step_max
      means = transform_and_means(contents, emb_df_clean, time_step_max,
                                  percent_norm=percent_norm,
                                  factor_analys=factor_analys)
      means_valid = transform.drop_unevaluated_comp(means)
      dist_at_time[time_step_max] = distance_analysis.matrix(
          distance.cosine, means_valid)
    dist = dist_at_time[time_step_max]

    # k-NN up to k=4
    for k in range(1, 5):
      for which_filter, correct, mismatch, matrices in variants:
        update_stats_new_compound(comp_set, dist, k, which_filter,
                                  correct, mismatch, moa_values, matrices)

  ## cross-validated accuracies from the accumulated correct / mismatched
  cross_validated_scores = {
      "acc_nsc": calculate_moa_accuracy(correct_nsc, mismatch_nsc),
      "acc_nscb": calculate_moa_accuracy(correct_nscb, mismatch_nscb),
  }
  if report_confusion_matrix:
    cross_validated_scores["confusion_matrices_nsc"] = confusion_matrices_nsc
    cross_validated_scores["confusion_matrices_nscb"] = confusion_matrices_nscb
  return (list_of_time_step_max, cross_validated_scores)


def calculate_moa_accuracy(correct, mismatch):
  """Calculate MOA accuracy (in percent) from correct and mismatch.

  For each k in 1...4 the accuracy is
  len(correct[k]) / (len(correct[k]) + len(mismatch[k])), rounded to three
  decimals and then scaled to percent. If a given k has no entries at all the
  accuracy is undefined and reported as NaN instead of raising
  ZeroDivisionError.

  Args:
    correct (defaultdict): correct compounds, keyed by k
    mismatch (defaultdict): mismatched compounds, keyed by k

  Returns:
    acc (numpy array): accuracy in percent for k=1...4.
  """
  acc = []
  for k in range(1, 5):
    num_correct = len(correct[k])
    total = num_correct + len(mismatch[k])
    # float() keeps true division even without `from __future__ import
    # division`.
    acc.append(float(num_correct) / total if total else float("nan"))
  return np.round(acc, 3) * 100.0


def elementwise_stats(knns):
  """Collect, per statistic and per k, the values from every input dict.

  Only statistics present in every element of `knns` are kept, and for each of
  them only the values of k present in every element.

  Args:
    knns (list of dicts):
      The keys are which statistic is considered
      (e.g. 'Accuracy NSC').
      Each value should itself be a dict, whose keys
      correspond to k-NN.
  Returns:
    stats_dict (dict): maps statistic -> k -> list of values, one per element
      of `knns` and in the same order. Empty when `knns` is empty.
  """
  knns = list(knns)
  if not knns:
    # set.intersection() needs at least one operand.
    return {}
  stats_dict = {}
  shared_keys = set.intersection(*[set(knn) for knn in knns])
  for key in shared_keys:
    shared_ks = set.intersection(*[set(knn[key]) for knn in knns])
    stats_dict[key] = {k: [knn[key][k] for knn in knns] for k in shared_ks}
  return stats_dict


def knn_bootstrap(emb_df, num_bootstrap=2, seed=SEED, percent_norm=False,
                  factor_analys=False):
  """Compute k-NN and clustering scores on bootstrap samples.

  Each replicate draws a bootstrap sample, post-processes it, drops the
  unevaluated compounds, averages per treatment and scores the means.

  Args:
    emb_df (pandas dataframe): dataframe to use (includes controls)
    num_bootstrap (int): number of bootstrap reps
    seed (int): which seed value to use for bootstrapping
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis
  Returns:
    knn_return (dict): "knn_scores" holds, for each statistic and value of k,
      the list of values over the bootstrap samples; "clustering_scores" holds
      the list of clustering scores, one per bootstrap sample.
  """
  group_levels = [
      metadata.MOA, metadata.COMPOUND,
      metadata.CONCENTRATION, metadata.BATCH,
      metadata.TREATMENT_GROUP]
  knn_per_sample = []
  clustering_per_sample = []

  np.random.seed(seed=seed)
  for _ in range(num_bootstrap):
    sample = transform.get_bootstrap_sample(emb_df)
    sample = apply_post_processing(sample,
                                   percent_norm=percent_norm,
                                   factor_analys=factor_analys)
    evaluated = transform.drop_unevaluated_comp(sample)
    sample_means = evaluated.groupby(level=group_levels).mean()
    sample_scores = get_scores_from_means(sample_means,
                                          report_confusion_matrix=False)
    knn_per_sample.append(sample_scores["knn"])
    clustering_per_sample.append(sample_scores["clustering_score"])

  return {"knn_scores": elementwise_stats(knn_per_sample),
          "clustering_scores": clustering_per_sample}


def cross_val_knn_bootstrap(emb_df, contents, steps, list_of_comp_set,
                            num_bootstrap=2, seed=SEED,
                            percent_norm=False,
                            factor_analys=False):
  """Run the leave-one-out cross validation on bootstrap samples.

  Args:
    emb_df (pandas dataframe): dataframe to use
    contents (dict): Contents from Wasserstein training routine
    steps (list): List of timesteps at which to evaluate
    list_of_comp_set (list): each element is a dict for cross-validation.
    num_bootstrap (int): number of bootstrap reps
    seed (int): which seed value to use for bootstrapping
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis

  Returns:
    cross_validated_scores (list): one entry per bootstrap sample, each being
      the value returned by cross_val_train (without confusion matrices).
  """
  n_comp = len(list_of_comp_set)

  def _one_replicate():
    boot_emb = transform.get_bootstrap_sample(emb_df)
    return cross_val_train(boot_emb, contents, steps,
                           list_of_comp_set, n_comp,
                           report_confusion_matrix=False,
                           percent_norm=percent_norm,
                           factor_analys=factor_analys)

  np.random.seed(seed=seed)
  return [_one_replicate() for _ in range(num_bootstrap)]


def get_scores_from_means(means, report_knn=True, report_confusion_matrix=True):
  """Score treatment means: silhouette, k-NN accuracy and confusion matrices.

  Args:
    means (pandas dataframe): means for each treatment.
    report_knn (boolean): whether or not to compute KNN scores.
    report_confusion_matrix (boolean): whether or not to include confusion
      matrix.
  Returns:
    dict containing the following:
      clustering_score (float): silhouette score of the MOA labels on the
        cosine distance matrix. Always present.
      knn (dict): dict form of evaluate.make_knn_moa_dataframe(means). Only
        when report_knn is set.
      confusion_matrix (dict): confusion matrices under "nsc" and "nscb", each
        keyed by k=1...4. Only when report_confusion_matrix is set.
  """
  moa_name_index = get_index_for_name(means, "moa")
  dist = distance_analysis.matrix(distance.cosine, means)
  moa_labels = means.index.get_level_values(level=metadata.MOA)
  output_dict = {
      "clustering_score": metrics.silhouette_score(
          dist, labels=moa_labels, metric="precomputed")
  }

  if report_knn:
    output_dict["knn"] = evaluate.make_knn_moa_dataframe(means).to_dict()

  if report_confusion_matrix:
    filters = (
        ("nsc", evaluate.not_same_compound_filter),
        ("nscb", evaluate.not_same_compound_or_batch_filter),
    )
    confusion_matrix = {"nsc": {}, "nscb": {}}
    for k in range(1, 5):
      for filter_name, which_filter in filters:
        confusion_matrix[filter_name][k] = confusion_matrix_from_dist(
            dist, k, which_filter, dist.index.levels[moa_name_index])
    output_dict["confusion_matrix"] = confusion_matrix
  return output_dict


def get_batch_classification_scores(emb_df_clean, seed=SEED):
  """Batch classification scores for logistic regression and random forests.

  Args:
    emb_df_clean (pandas dataframe): input dataframe
    seed (int): seed to use for the random state of both classifiers
  Returns:
    batch_classification_scores (dict): maps "logistic regression" and
      "random forest" to the dict form of
      evaluate.make_batch_classifier_score_dataframe for that classifier.

  """
  classifiers = (
      ("logistic regression", lm.LogisticRegression(random_state=seed)),
      ("random forest",
       ensemble.RandomForestClassifier(n_estimators=100, random_state=seed)),
  )
  return {
      name: evaluate.make_batch_classifier_score_dataframe(
          emb_df_clean, classifier).to_dict()
      for name, classifier in classifiers
  }


def test_tvn(emb_df_clean, num_bootstrap=2, percent_norm=False,
             factor_analys=False):
  """Evaluate the input embeddings as-is (TVN only).

  Args:
    emb_df_clean (pandas dataframe): input embeddings.
    num_bootstrap (int): number of bootstrap samples to use
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis
  Returns:
    metrics_dict_tvn (dict): contains batch_classification_scores,
      bootstrap_scores and the entries of get_scores_from_means.
  """
  group_levels = [
      metadata.MOA, metadata.COMPOUND, metadata.CONCENTRATION, metadata.BATCH,
      metadata.TREATMENT_GROUP
  ]
  emb_df_post = apply_post_processing(emb_df_clean, percent_norm, factor_analys)
  evaluated = transform.drop_unevaluated_comp(emb_df_post)
  means = evaluated.groupby(level=group_levels).mean()

  results = {
      "batch_classification_scores":
          get_batch_classification_scores(emb_df_post)
  }
  moa_scores = get_scores_from_means(means)
  ## Bootstrapping starts from the embeddings before post-processing;
  ## knn_bootstrap post-processes each sample itself.
  results["bootstrap_scores"] = knn_bootstrap(
      emb_df_clean, num_bootstrap=num_bootstrap,
      percent_norm=percent_norm,
      factor_analys=factor_analys)
  results.update(moa_scores)
  return results


def test_wdn(emb_df_clean, contents, list_of_time_step_max, steps,
             list_of_comp_set, num_bootstrap=2, percent_norm=False,
             factor_analys=False):
  """Evaluate the Wasserstein-transformed embeddings (WDN).

  Args:
    emb_df_clean (pandas dataframe): input dataframe
    contents (dict): Contents from Wasserstein training routine
    list_of_time_step_max (list): timesteps at which to evaluate WDN
      statistics. For example, could be the time step where average nsc and
      nscb for k=1...4 is maximized for a given compound in the
      cross-validation procedure. Duplicates are evaluated once.
    steps (list): all timesteps from analysis, used for bootstrapping.
    list_of_comp_set (list): each element is a dict for cross-validation.
    num_bootstrap (int): number of bootstrap samples to use
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis

  Returns:
    metrics_dict_wdn (dict): contains batch_classification_scores and
      clustering_scores (both keyed by time step) and knn_bootstrap_scores.
  """
  group_levels = [
      metadata.MOA, metadata.COMPOUND, metadata.CONCENTRATION, metadata.BATCH,
      metadata.TREATMENT_GROUP
  ]
  batch_classification_scores = {}
  clustering_scores = {}

  ##  We do not do cross validation for batch classification and Silhouette
  ##  scores. For the BBBC021 dataset, batch classification only applies to
  ##  controls, so the result of leave-one-out cross validation are the same as
  ##  taking the weighted average/standard deviation across left-out compounds.
  ##  For the Silhouette score, it is possible to do leave-one-out cross
  ##  validation, but then we would also have to do it for TVN and CORAL for
  ##  each left-out compound as well.
  for time_step_max in set(list_of_time_step_max):
    ## Both the transformed embeddings and their means are needed, so
    ## transform_and_means is not used here.
    transformed = wasserstein_transform(contents, emb_df_clean, time_step_max)
    post_processed = apply_post_processing(transformed, percent_norm,
                                           factor_analys)
    per_treatment = post_processed.groupby(level=group_levels).mean()
    means = transform.drop_unevaluated_comp(per_treatment)

    batch_classification_scores[time_step_max] = (
        get_batch_classification_scores(post_processed))
    scores_at_time = get_scores_from_means(means, report_knn=False,
                                           report_confusion_matrix=False)
    clustering_scores[time_step_max] = scores_at_time["clustering_score"]

  knn_bootstrap_scores = cross_val_knn_bootstrap(
      emb_df_clean, contents, steps, list_of_comp_set,
      num_bootstrap=num_bootstrap,
      percent_norm=percent_norm,
      factor_analys=factor_analys)
  return {
      "batch_classification_scores": batch_classification_scores,
      "knn_bootstrap_scores": knn_bootstrap_scores,
      "clustering_scores": clustering_scores
  }


def test_coral(emb_df_clean, num_bootstrap=2, percent_norm=False,
               factor_analys=False):
  """Evaluate the embeddings after transform.coral_without_mean_shift_batch.

  Args:
    emb_df_clean (pandas dataframe): input dataframe
    num_bootstrap (int): number of bootstrap samples to use
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis

  Returns:
    metrics_dict_coral (dict): contains batch_classification_scores,
      bootstrap_scores and the entries of get_scores_from_means.
  """
  group_levels = [
      metadata.MOA, metadata.COMPOUND, metadata.CONCENTRATION, metadata.BATCH,
      metadata.TREATMENT_GROUP
  ]
  coral_emb = transform.coral_without_mean_shift_batch(emb_df_clean)
  emb_df_post = apply_post_processing(coral_emb, percent_norm, factor_analys)
  evaluated = transform.drop_unevaluated_comp(emb_df_post)
  means = evaluated.groupby(level=group_levels).mean()

  results = {
      "batch_classification_scores":
          get_batch_classification_scores(emb_df_post)
  }
  moa_scores = get_scores_from_means(means)
  ## Bootstrapping starts from the CORAL output before post-processing;
  ## knn_bootstrap post-processes each sample itself.
  results["bootstrap_scores"] = knn_bootstrap(
      coral_emb,
      num_bootstrap=num_bootstrap,
      percent_norm=percent_norm,
      factor_analys=factor_analys)
  results.update(moa_scores)
  return results


def evaluate_metrics(contents, emb_df_clean, list_of_time_step_max, steps,
                     list_of_comp_set, num_bootstrap=2, percent_norm=False,
                     factor_analys=False):
  """Run the TVN, CORAL and WDN evaluations with shared options.

  Args:
    contents (dict): Contents from Wasserstein training routine
    emb_df_clean (pandas dataframe): input dataframe
    list_of_time_step_max (list): list of all timesteps to use for wdn
    steps (list): all timesteps from analysis, used for bootstrapping.
    list_of_comp_set (list): dictionaries for each compound for leave-one-out
    num_bootstrap (int): number of bootstrap samples to use
    percent_norm (bool): whether to apply percentile normalization
    factor_analys (bool): whether to apply factor analysis

  Returns:
    metrics_dict (dict): maps "tvn", "coral" and "wdn" to the metrics of the
      corresponding method.

  """
  shared_options = dict(num_bootstrap=num_bootstrap,
                        percent_norm=percent_norm,
                        factor_analys=factor_analys)
  ## Entries are evaluated in the order written: tvn, coral, wdn.
  return {
      "tvn": test_tvn(emb_df_clean, **shared_options),
      "coral": test_coral(emb_df_clean, **shared_options),
      "wdn": test_wdn(emb_df_clean, contents, list_of_time_step_max,
                      steps, list_of_comp_set, **shared_options),
  }


def main(argv):
  """Run cross validation and all evaluations, then pickle the results.

  Reads the embeddings and the transformation file named by the flags, selects
  a stopping time per left-out compound, evaluates TVN, CORAL and WDN, and
  writes a pickled dict to FLAGS.output_file.

  Args:
    argv: command-line arguments left over after flag parsing; unused.

  Raises:
    ValueError: if "treatment_group" is not one of the index names of the
      embeddings, or if the transformation file has no "params" entry.
  """
  del argv  # Unused.

  emb_df_clean = io_utils.read_dataframe_from_hdf5(FLAGS.input_df)
  if "treatment_group" not in emb_df_clean.index.names:
    raise ValueError("Must have treatment_group in embeddings index names.")
  contents = load_contents(FLAGS.transformation_file)

  ## Training steps are every key of the contents except "params"; keep the
  ## first FLAGS.num_steps of them in sorted order.
  steps = list(contents.keys())
  steps.remove("params")
  steps = np.sort(steps)[:FLAGS.num_steps]

  ## embeddings without unevaluated compound
  emb_df_valid = transform.drop_unevaluated_comp(emb_df_clean)
  if "treatment_group" not in emb_df_valid.index.names:
    raise ValueError("Must have treatment_group in embeddings index names.")

  ## list of compounds and number of compounds
  comp_list = emb_df_valid.index.get_level_values(
      level=metadata.COMPOUND).unique()
  n_comp = len(comp_list)

  ## Leave-one-out sets: "b" is the left-out compound, "a" all the others.
  all_comps = set(comp_list)
  list_of_comp_set = [
      {"b": comp, "a": list(all_comps.difference([comp]))}
      for comp in comp_list
  ]

  ## Cross validation training with leave-one-out and variable stopping time.
  (steps_max, cross_validated_scores) = cross_val_train(
      emb_df_clean, contents, steps, list_of_comp_set, n_comp,
      percent_norm=FLAGS.percentile_normalize,
      factor_analys=FLAGS.factor_analysis)

  ## Bootstraps only need the steps between the first and last selected ones.
  first_step = np.min(steps_max)
  last_step = np.max(steps_max)
  boot_steps = [step for step in steps if last_step >= step >= first_step]

  metrics_dict = evaluate_metrics(contents, emb_df_clean, steps_max,
                                  boot_steps, list_of_comp_set,
                                  num_bootstrap=FLAGS.num_bootstrap,
                                  percent_norm=FLAGS.percentile_normalize,
                                  factor_analys=FLAGS.factor_analysis)

  ## cross-validated scores, obtained at time_step_max for each individual
  ## compound, are stored alongside the other WDN metrics.
  metrics_dict["wdn"]["cross_val_scores"] = cross_validated_scores

  save_dict = {
      "metrics_dict": metrics_dict,
      ## time steps where max cross validation results were found
      "list_of_time_step_max": steps_max,
  }

  ## pickle.dumps returns bytes, so the file is opened in binary mode.
  with gfile.GFile(FLAGS.output_file, mode="wb") as f:
    f.write(pickle.dumps(save_dict))


if __name__ == "__main__":
  app.run(main)
